{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding progress bars to Learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_09b import *\n",
    "import time\n",
    "from fastprogress.fastprogress import master_bar, progress_bar\n",
    "from fastprogress.fastprogress import format_time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One thing has been missing all this time, and as fun as it is to stare at a blank screen waiting for the results, it's nicer to have some tool to track progress."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imagenette data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 11 video](https://course19.fast.ai/videos/?lesson=11&t=6743)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local path for the IMAGENETTE_160 URL, as returned by `datasets.untar_data`\n",
    "# (presumably downloads/extracts the archive when needed - confirm against nb_09b).\n",
    "path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `tfms` is handed to the ImageList below; `bs` is passed first to `to_databunch`\n",
    "# (presumably the batch size - confirm).\n",
    "tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]\n",
    "bs = 64\n",
    "\n",
    "# Pipeline: build an ImageList from the files under `path`, split it with\n",
    "# `grandparent_splitter` (valid_name='val'), label it with `parent_labeler`\n",
    "# (targets processed by a CategoryProcessor), then wrap it in a databunch.\n",
    "il = ImageList.from_files(path, tfms=tfms)\n",
    "sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))\n",
    "ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())\n",
    "data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four entries of 32, passed as the first argument to `get_learner` further down\n",
    "# (presumably the number of filters per conv layer - confirm against get_learner).\n",
    "nfs = [32]*4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We rewrite the `AvgStatsCallback` to add a line with the names of the things measured and keep track of the time per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export \n",
    "class AvgStatsCallback(Callback):\n",
    "    # Keeps one stats tracker for training and one for validation, logs a header\n",
    "    # row when fitting starts and one formatted row (with elapsed time) per epoch.\n",
    "    def __init__(self, metrics):\n",
    "        # The boolean is presumably the in_train flag of AvgStats - confirm there.\n",
    "        self.train_stats = AvgStats(metrics, True)\n",
    "        self.valid_stats = AvgStats(metrics, False)\n",
    "\n",
    "    def begin_fit(self):\n",
    "        # Header row: epoch, train_<name>..., valid_<name>..., time\n",
    "        columns = ['loss']\n",
    "        columns.extend(metric.__name__ for metric in self.train_stats.metrics)\n",
    "        header = ['epoch']\n",
    "        for prefix in ('train', 'valid'):\n",
    "            header.extend(f'{prefix}_{column}' for column in columns)\n",
    "        header.append('time')\n",
    "        self.logger(header)\n",
    "\n",
    "    def begin_epoch(self):\n",
    "        # Clear both trackers and start the per-epoch timer.\n",
    "        for tracker in (self.train_stats, self.valid_stats):\n",
    "            tracker.reset()\n",
    "        self.start_time = time.time()\n",
    "\n",
    "    def after_loss(self):\n",
    "        # Accumulate into the tracker that matches the current phase.\n",
    "        if self.in_train:\n",
    "            tracker = self.train_stats\n",
    "        else:\n",
    "            tracker = self.valid_stats\n",
    "        with torch.no_grad():\n",
    "            tracker.accumulate(self.run)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        # Row layout mirrors the header logged in begin_fit.\n",
    "        row = [str(self.epoch)]\n",
    "        for tracker in (self.train_stats, self.valid_stats):\n",
    "            row.extend(f'{value:.6f}' for value in tracker.avg_stats)\n",
    "        row.append(format_time(time.time() - self.start_time))\n",
    "        self.logger(row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we add the progress bars... with a Callback of course! `master_bar` handles the count over the epochs while its child `progress_bar` is looping over all the batches. We just create one at the beginning of each epoch/validation phase, and update it at the end of each batch. By changing the logger of the `Learner` to the `write` function of the master bar, everything is automatically written there.\n",
    "\n",
    "Note: this requires fastprogress v0.1.21 or later. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export \n",
    "class ProgressCallback(Callback):\n",
    "    # Shows a master bar over the epochs with a child bar over the batches of the\n",
    "    # current dataloader, and points the run's logger at the master bar's `write`\n",
    "    # (table=True) so logged rows are written there.\n",
    "    _order=-1\n",
    "    def begin_fit(self):\n",
    "        self.mbar = master_bar(range(self.epochs))\n",
    "        self.mbar.on_iter_begin()\n",
    "        self.run.logger = partial(self.mbar.write, table=True)\n",
    "\n",
    "    def after_fit(self): self.mbar.on_iter_end()\n",
    "\n",
    "    # `progress_bar.update` takes the number of items completed so far. `self.iter`\n",
    "    # is assumed to be the zero-based index of the batch that just finished (confirm\n",
    "    # against the Learner), hence the +1; without it the bar never reaches its total.\n",
    "    def after_batch(self): self.pb.update(self.iter+1)\n",
    "    def begin_epoch   (self): self.set_pb()\n",
    "    def begin_validate(self): self.set_pb()\n",
    "\n",
    "    def set_pb(self):\n",
    "        # A fresh child bar for the current dataloader (training or validation).\n",
    "        self.pb = progress_bar(self.dl, parent=self.mbar)\n",
    "        self.mbar.update(self.epoch)\n",
    "        # Start the child bar at zero before the first batch runs, so its timer\n",
    "        # covers the first batch too (after_batch no longer sends a 0 itself).\n",
    "        self.pb.update(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By making the progress bar a callback, you can easily choose if you want to have them shown or not."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback factories, passed to `get_learner` as `cb_funcs` in the next cell.\n",
    "# `partial` pre-binds the first constructor argument: `accuracy` as the `metrics`\n",
    "# of AvgStatsCallback, `norm_imagenette` for BatchTransformXCallback.\n",
    "# Leave ProgressCallback out of the list to train without progress bars.\n",
    "cbfs = [partial(AvgStatsCallback,accuracy),\n",
    "        CudaCallback,\n",
    "        ProgressCallback,\n",
    "        partial(BatchTransformXCallback, norm_imagenette)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the learner from `nfs`, the databunch, 0.4 (presumably the learning rate -\n",
    "# confirm against get_learner), `conv_layer` and the callback factories in `cbfs`.\n",
    "learn = get_learner(nfs, data, 0.4, conv_layer, cb_funcs=cbfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='' max='2' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      \n",
       "    </div>\n",
       "    \n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>train_accuracy</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>valid_accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.839106</td>\n",
       "      <td>0.356849</td>\n",
       "      <td>1.940670</td>\n",
       "      <td>0.364076</td>\n",
       "      <td>00:05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.500339</td>\n",
       "      <td>0.490865</td>\n",
       "      <td>1.692970</td>\n",
       "      <td>0.423439</td>\n",
       "      <td>00:05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 09c_add_progress_bar.ipynb to exp/nb_09c.py\r\n"
     ]
    }
   ],
   "source": [
    "!./notebook2script.py 09c_add_progress_bar.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
